"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


Create plot of the clustering of the mice dataset (Fig. 6).

Saves the clustering as `multi_weeks_partitions.pickle`
and `wild_mice_flow_stability_clustering.json`.


"""

import sys
import os
# Append the parent directory of this script to sys.path
# (presumably where TemporalNetwork / TemporalStability live -- confirm).
PACKAGE_PARENT = '..'
_script_path = os.path.join(os.getcwd(), os.path.expanduser(__file__))
SCRIPT_DIR = os.path.dirname(os.path.realpath(_script_path))
_package_root = os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))
sys.path.append(_package_root)



import numpy as np
import pandas as pd
from TemporalNetwork import ContTempNetwork
from TemporalStability import Partition


import matplotlib.pyplot as plt

import matplotlib.dates as mpldates
from matplotlib.colors import ListedColormap


from datetime import timedelta

# Colour palette shared by the plots below; the first entry is white.
color_list =["#ffffff",
"#4ba706",
"#a2007e",
"#806dcb",
"#5eb275",
"#ca3b01",
"#01a4d6",
"#b77600",
"#a39643",
"#cc6ea9",
"#1e5e39",
"#cb5b5a"]

# Colormap over the full palette (white included).
cmap = ListedColormap(color_list)

# Same palette without its first (white) entry.
cmap_now = ListedColormap(color_list[1:])

# NOTE(review): 'alex_paper' is not a style defined in this file; it
# presumably has to be installed in the local matplotlib style library.
plt.style.use('alex_paper')

import matplotlib as mpl
# Override the line width after the style sheet has been applied.
mpl.rcParams['lines.linewidth'] = 1.7


# Input/output locations, relative to the current working directory.
datadir = '../paper_data/mice_data_march_april2017/'

# Directory listed below to discover the available clustering result files.
clusterdir = os.path.join(datadir, 'clustersI')

# Directory where the Sankey JSON/SVG files are written.
figdir = '../figures/micenet_march_april2017'

# Prefix used when building the names of the cluster result files.
# NOTE(review): reads '2007' while the data directory says 2017 --
# presumably matches the file names on disk; confirm.
file_prefix = 'micenet2007'

# Path handed to ContTempNetwork.load to read the network.
net_file = os.path.join(datadir,'micenet_2017_02_28_to_2017_05_01')


# NOTE(review): unconditional abort -- nothing below this line runs when the
# file is executed top to bottom. Presumably a guard so that the `#%%` cells
# are run one at a time from an interactive session; confirm.
raise Exception

#%%


# Discover which intervals and which tau_w values have results on disk.
# Each file name (without extension) is split on '_': the last field and the
# third-to-last field are parsed as the integer interval bounds, and every
# field starting with 'w' is parsed as a tau_w value (e.g. 'w1.000e-01').
files = os.listdir(clusterdir)

interval_starts = set()
interval_ends = set()
tau_ws = set()
for fname in files:
    parts = os.path.splitext(fname)[0].split('_')
    interval_starts.add(int(parts[-3]))
    interval_ends.add(int(parts[-1]))
    tau_ws.update(float(part[1:]) for part in parts
                  if part.startswith('w'))
            
            
            
#%%% network file

# Attributes requested when loading the temporal network.
# NOTE(review): `net.node_array` is used further down but is not listed
# here -- presumably it is always available after loading; confirm.
net_attributes = [
    'node_to_label_dict',
    'events_table',
    'times',
    'time_grid',
    'num_nodes',
    'start_date',
    'male_array',
    'female_array',
]

net = ContTempNetwork.load(net_file, attributes_list=net_attributes)


    
#%% 

num_intervals = len(interval_ends)

# Regular grid of slice boundaries covering [t0, tend]; net.times is
# interpreted as seconds when converted to dates further down.
t0 = 0
tend = net.times[-1]
slice_length = 60 * 60  # 1h

# Keep adding boundaries until the first one that lies beyond `tend`.
time_slices = [t0]
t = t0
while t <= tend:
    t = t + slice_length
    time_slices.append(t)

# Interval start indices found on disk, and the matching slice times.
intervals = sorted(interval_starts)

int_times = [time_slices[idx] for idx in intervals]
    

#%% plot datetimes
        
# Calendar date of every entry of net.times (offsets in seconds from start_date).
datetimes = [net.start_date + timedelta(seconds=sec) for sec in net.times]

# Hourly grid from the start date to the last time, and the dates at which
# the intervals found on disk start.
last_datetime = net.start_date + timedelta(seconds=net.times[-1])
datetimes_slices = pd.date_range(start=net.start_date, end=last_datetime,
                                 freq='1H')
datetimes_intervals = [datetimes_slices[idx].to_pydatetime() for idx in intervals]

# Diagnostic plot: index of each time against its date, with a labelled
# vertical line at every interval start.
fig, ax = plt.subplots(1, 1)
ax.plot(datetimes, range(len(datetimes)))
ax.vlines(datetimes_intervals, ymin=0, ymax=len(datetimes))

for interval_date, interval_start in zip(datetimes_intervals, intervals):
    ax.text(interval_date, 0, str(interval_start))

ax.xaxis.set_minor_locator(mpldates.DayLocator())


    
#%% sort clusters
        
def sort_clusters(cluster_list_to_sort, cluster_list_model, thresh_ratio=0.3):
    """
    Reorder `cluster_list_to_sort` so that it follows `cluster_list_model`.

    The Jaccard similarity between every cluster to sort and every model
    cluster is computed. Clusters are then placed pass by pass: in each pass
    the model clusters are visited in order and, for each of them, the not yet
    placed clusters whose similarity is at least `thresh_ratio` times the best
    remaining similarity to that model cluster are appended, most similar
    first. Clusters that overlap with no model cluster are put at the end, in
    their original order.

    Parameters
    ----------
    cluster_list_to_sort : sequence of sets
        Clusters to reorder.
    cluster_list_model : sequence of sets
        Clusters defining the reference order.
    thresh_ratio : float, optional
        Minimum similarity, relative to the best remaining match of a model
        cluster, for a cluster to be placed with that model cluster.

    Returns
    -------
    list of sets
        The clusters of `cluster_list_to_sort`, reordered. An empty input
        gives an empty list.
    """
    num_clusts = len(cluster_list_to_sort)
    if num_clusts == 0:
        # nothing to sort (the matrix below would be 1-D and `.sum(1)` fail)
        return []

    # Jaccard similarity matrix, rows: clusters to sort, columns: model clusters.
    clust_similarity_matrix = np.zeros((num_clusts, len(cluster_list_model)),
                                       dtype=float)
    for row, clust in enumerate(cluster_list_to_sort):
        for col, class_clust in enumerate(cluster_list_model):
            union_size = len(clust.union(class_clust))
            # two empty clusters have an empty union: similarity is left at 0
            # instead of dividing by zero
            if union_size > 0:
                clust_similarity_matrix[row, col] = \
                    len(clust.intersection(class_clust)) / union_size

    # clusters without any overlap with the model go last
    zero_clusts = (clust_similarity_matrix.sum(1) == 0).nonzero()[0].tolist()
    zero_set = set(zero_clusts)

    # indices of the clusters that still have to be placed
    all_clusts = [i for i in range(num_clusts) if i not in zero_set]
    num_to_place = len(all_clusts)

    new_clust_order = []
    placed = set()  # same content as new_clust_order, for fast membership tests

    while len(new_clust_order) < num_to_place:
        num_placed_before = len(new_clust_order)

        for cla in range(clust_similarity_matrix.shape[1]):
            # loop on classes and sort according to most similar
            column = clust_similarity_matrix[all_clusts, cla]
            sorted_comms = column.argsort()[::-1]
            scores = column[sorted_comms]
            best_score = scores.max()
            if best_score > 0:
                scores = scores / best_score
                for c, s in zip(sorted_comms, scores):
                    candidate = all_clusts[c]
                    if s >= thresh_ratio and candidate not in placed:
                        new_clust_order.append(candidate)
                        placed.add(candidate)

        if len(new_clust_order) == num_placed_before:
            # No cluster reached the threshold in a full pass (only possible
            # when thresh_ratio > 1): keep the remaining clusters in their
            # current order instead of looping forever.
            new_clust_order.extend(all_clusts)
            break

        # update all_clusts
        all_clusts = [i for i in all_clusts if i not in placed]

    return [cluster_list_to_sort[i] for i in new_clust_order + zero_clusts]




  
#%% make data for multiple flows

# For every tau_w, collect the forward and backward partitions (restricted to
# the nodes active in each window) and their 'nvarinf_sym' values.
multi_res = {}

# candidate values, immediately superseded by every tau_w found on disk
taus_to_plot = [0.1, 1.0, 8.0, 60.0, 3600.0, 86400.0, 604800.0, 4234000.0]

taus_to_plot = tau_ws

# [start, stop] indices into the hourly `time_slices` grid:
# consecutive windows of 168 slices, the last one being shorter.
week_bounds = [[0, 168], [168, 336], [336, 504], [504, 672],
               [672, 840], [840, 1008], [1008, 1176], [1176, 1344], [1344, 1487]]

# pieces of the cluster result file names
cluster_file_head = 'clusters_' + file_prefix + '_tau_w'
cluster_file_tail = '{0:.3e}_I_lin_{1:06d}_to_{2:06d}.pickle'

for tau_w in taus_to_plot:

    multi_res[tau_w] = {}

    multi_res[tau_w]['partitions_back'] = []
    multi_res[tau_w]['partitions_forw'] = []
    multi_res[tau_w]['nvarinfs_forw'] = []
    multi_res[tau_w]['nvarinfs_back'] = []

    # backward partition of the previous window, used to align the next
    # forward partition with it
    best_partrev = None

    for int_start, int_stop in week_bounds:

        # find active nodes during this time: sources and targets of the
        # events that start at or after the window start and end before its stop
        events = net.events_table
        in_window = np.logical_and(events.starting_times >= time_slices[int_start],
                                   events.ending_times < time_slices[int_stop])
        window_events = events.loc[in_window]
        active_nodes = set(window_events.source_nodes.tolist())
        active_nodes.update(window_events.target_nodes.tolist())

        # forward result (start -> stop) and reversed result (stop -> start)
        forw_file = os.path.join(
            clusterdir,
            cluster_file_head + cluster_file_tail.format(tau_w, int_start, int_stop))
        back_file = os.path.join(
            clusterdir,
            cluster_file_head + cluster_file_tail.format(tau_w, int_stop, int_start))
        clustres = pd.read_pickle(forw_file)
        clustres_rev = pd.read_pickle(back_file)
        print(clustres['num_repeat'])
        print(clustres_rev['num_repeat'])

        # partitions with only active nodes
        forw_clusters = [clust.intersection(active_nodes)
                         for clust in clustres['best_cluster_sym']]
        best_part = Partition(len(active_nodes), forw_clusters)
        if best_partrev is None:
            # first window: sort according to cluster size, largest first
            best_part.cluster_list = sorted(best_part.cluster_list,
                                            key=len, reverse=True)
        else:
            # sort according to last partrev
            best_part.cluster_list = sort_clusters(best_part.cluster_list,
                                                   best_partrev.cluster_list)

        # the backward partition is aligned with the forward one of this window
        back_clusters = [clust.intersection(active_nodes)
                         for clust in clustres_rev['best_cluster_sym']]
        best_partrev = Partition(len(active_nodes),
                                 sort_clusters(back_clusters, best_part.cluster_list))

        multi_res[tau_w]['partitions_forw'].append(best_part)
        multi_res[tau_w]['partitions_back'].append(best_partrev)

        multi_res[tau_w]['nvarinfs_forw'].append(clustres['nvarinf_sym'])
        multi_res[tau_w]['nvarinfs_back'].append(clustres_rev['nvarinf_sym'])
        

#%%
# Node sets per sex; 'unknown' holds the nodes that are in neither set.
male_nodes = set(net.node_array[net.male_array])
female_nodes = set(net.node_array[net.female_array])
sex_dict = {'male': male_nodes,
            'female': female_nodes,
            'unknown': set(net.node_array) - male_nodes - female_nodes}

# Store the sex classes next to the per-tau_w results and pickle everything.
multi_res['sex_dict'] = sex_dict

partitions_pickle = os.path.join(datadir, 'multi_weeks_partitions.pickle')
pd.to_pickle(multi_res, partitions_pickle)

#%% save results for sharing
# One plain dict per tau_w (cluster lists instead of Partition objects),
# followed by the sex classes as the last element.
multi_res_share = []
for tau_w in tau_ws:
    tau_res = multi_res[tau_w]
    entry = {'tau_w': tau_w}
    entry['partitions_forward_per_week'] = [part.cluster_list for part
                                            in tau_res['partitions_forw']]
    entry['partitions_backward_per_week'] = [part.cluster_list for part
                                             in tau_res['partitions_back']]
    entry['NVI_forward_per_week'] = tau_res['nvarinfs_forw']
    entry['NVI_backward_per_week'] = tau_res['nvarinfs_back']
    multi_res_share.append(entry)
multi_res_share.append(sex_dict)
import json

class NpEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy scalars, NumPy arrays and sets."""

    def default(self, obj):
        """
        Convert `obj` to a JSON-serialisable Python object.

        NumPy booleans, integers and floats become `bool`, `int` and `float`,
        arrays become (nested) lists, and sets / frozensets become lists (in
        the iteration order of the set). Anything else is passed to the base
        class, which raises `TypeError`.
        """
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, (set, frozenset)):
            return list(obj)
        return super().default(obj)


# Write the shareable results as JSON; NpEncoder converts the NumPy values and sets.
json_share_path = '../data/wild_mice_flow_stability_clustering.json'
with open(json_share_path, 'w') as fopen:
    json.dump(multi_res_share, fopen, cls=NpEncoder)
    
#%% avg + std NVI


# taus_to_plot = sorted(list(tau_ws))
taus_to_plot = [0.1, 1.0, 8.0, 60.0, 3600.0, 86400.0, 604800.0, 4234000.0]


def _mean_group_size(partition):
    """Mean size of the clusters of `partition` that have more than one node."""
    return np.mean([len(c) for c in partition.cluster_list if len(c) > 1])


# For each tau_w: forward values followed by backward values, all weeks pooled.
nvi_per_tau = [multi_res[tau]['nvarinfs_forw'] + multi_res[tau]['nvarinfs_back']
               for tau in taus_to_plot]
csize_per_tau = [[_mean_group_size(p) for p in multi_res[tau]['partitions_forw']] +
                 [_mean_group_size(p) for p in multi_res[tau]['partitions_back']]
                 for tau in taus_to_plot]

avg_NVI = [np.mean(values) for values in nvi_per_tau]
std_NVI = [np.std(values) for values in nvi_per_tau]

avg_csize = [np.mean(values) for values in csize_per_tau]
std_csize = [np.std(values) for values in csize_per_tau]


# Top panel: NVI, bottom panel: group size, both against tau_w
# (shared x axis, logarithmic).
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

for axis, means, stds in ((ax1, avg_NVI, std_NVI),
                          (ax2, avg_csize, std_csize)):
    axis.errorbar(taus_to_plot, means, yerr=stds, fmt='o-', color=color_list[2])

ax1.set_xscale('log')


ax2.set_xlabel(r'$\tau_w$ [s]')
ax1.set_ylabel('Norm. Var. Inf.')
ax2.set_ylabel('Avg. group size')

# plt.savefig(os.path.join(figdir, 'avg_std_NVI_weeks.png'),dpi=600)
# plt.savefig(os.path.join(figdir, 'avg_std_NVI_weeks.pdf'))


# pd.to_pickle({'taus_to_plot': taus_to_plot,
#               'avg_NVI': avg_NVI,
#               'std_NVI': std_NVI,
#               'avg_csize' : avg_csize,
#               'std_csize' : std_csize},
#              os.path.join(datadir, 'mice_nvi_plot_data.pickle')
#              )


#%% make data for flow diagram

# tau_w = 1.0
# tau_w = 60.0
tau_w = 86400.0
flows = []


def _record_flow(source, target, flow_type, value):
    """Append one flow entry to `flows`; empty flows (value 0) are skipped."""
    if value > 0:
        flows.append({'source': source,
                      'target': target, 'type': flow_type, 'value': value})


# Alternate forward ('wf<k>') and backward ('wb<k>') labels and partitions,
# week after week, with k starting at 1.
num_weeks = len(multi_res[tau_w]['partitions_forw'])
weeks = []
partitions = []
for week_idx in range(num_weeks):
    weeks += ['wf' + str(week_idx + 1), 'wb' + str(week_idx + 1)]
    partitions += [multi_res[tau_w]['partitions_forw'][week_idx],
                   multi_res[tau_w]['partitions_back'][week_idx]]

# node sets per sex; 'unknown' holds the nodes in neither set
male_nodes = set(net.node_array[net.male_array])
female_nodes = set(net.node_array[net.female_array])
sex_dict = {'male': male_nodes,
            'female': female_nodes,
            'unknown': set(net.node_array) - male_nodes - female_nodes}

#sort first partition: for tau_w == 1.0, clusters 3 and 7 are swapped by hand
if tau_w == 1.0:
    first_clusters = partitions[0].cluster_list
    first_clusters[3], first_clusters[7] = first_clusters[7], first_clusters[3]

past_clist = partitions[0].cluster_list
past_w = weeks[0]
past_nodes = set.union(*past_clist)

for w, part in zip(weeks[1:], partitions[1:]):
    # align the clusters of this partition with the previous one
    new_clist = sort_clusters(part.cluster_list, past_clist)

    # nodes of each sex class shared by a previous and a new cluster
    for clas, clas_set in sex_dict.items():
        for s, comm_s in enumerate(past_clist):
            for t, comm_t in enumerate(new_clist):
                _record_flow(past_w + f'_{s:02d}', w + f'_{t:02d}', clas,
                             len(clas_set.intersection(comm_s, comm_t)))

    new_nodes = set.union(*new_clist)

    if new_nodes != past_nodes:
        # nodes only present on one side are linked to the 'ext' node
        arriving_nodes = new_nodes.difference(past_nodes)
        leaving_nodes = past_nodes.difference(new_nodes)

        for clas, clas_set in sex_dict.items():
            #arriving nodes
            for t, comm_t in enumerate(new_clist):
                _record_flow('ext', w + f'_{t:02d}', clas,
                             len(clas_set.intersection(arriving_nodes, comm_t)))

            #leaving nodes
            for s, comm_s in enumerate(past_clist):
                _record_flow(past_w + f'_{s:02d}', 'ext', clas,
                             len(clas_set.intersection(leaving_nodes, comm_s)))

    past_clist = new_clist
    past_w = w
    past_nodes = new_nodes


df_flows = pd.DataFrame.from_dict(flows)
df_flows.to_csv(os.path.join(datadir, f'flows_multiple_tau_w={tau_w}.csv'))


#%% make sankey diagram

import floweaver as flo


def _column_title(date, suffix):
    """
    Return the title of a Sankey column, e.g. 'Mar 7 Forw.'.

    The day of the month is written without zero padding. `str(date.day)` is
    used instead of the '%-d' strftime directive, which is a platform-specific
    extension and not available everywhere (e.g. on Windows).
    """
    return date.strftime('%b') + ' ' + str(date.day) + suffix


# list of clusters per week forward and backward
# (weeks 1 to 9; forward clusters are read from the flow sources, backward
# clusters from the flow targets)

forward_part_dict = {'wf' + str(i) : sorted(df_flows.loc[df_flows.source.str.startswith('wf' + str(i))].source.unique().tolist()) for i in range(1,10)}
backward_part_dict = {'wb' + str(i) : sorted(df_flows.loc[df_flows.target.str.startswith('wb' + str(i))].target.unique().tolist()) for i in range(1,10)}

# flow types present in the table
class_list = sorted(df_flows.type.unique().tolist())


nodes = {}
ordering = []
bundles = []
lastpart = None

# One forward and one backward column per week; the forward column is titled
# with the date at which the week starts, the backward column with the date
# of the next interval start.
for w, (fwpart, bwpart) in enumerate(zip(sorted(forward_part_dict.keys()), sorted(backward_part_dict.keys()))):
    nodes[fwpart] = flo.ProcessGroup(selection=forward_part_dict[fwpart],
                                 partition=flo.Partition.Simple('process',forward_part_dict[fwpart]),
                                 title=_column_title(datetimes_intervals[w], ' Forw.'))
    nodes[bwpart] = flo.ProcessGroup(selection=backward_part_dict[bwpart],
                                 partition=flo.Partition.Simple('process',backward_part_dict[bwpart]),
                                 title=_column_title(datetimes_intervals[w+1], ' Back.'))
    ordering.append([fwpart])
    ordering.append([bwpart])

    # flows inside the week (forward -> backward) ...
    bundles.append(flo.Bundle(fwpart,bwpart))
    # ... and from the previous week's backward column to this forward column
    if lastpart is not None:
        bundles.append(flo.Bundle(lastpart,fwpart))

    lastpart = bwpart



class_flow_part = flo.Partition.Simple('type', class_list)

# Set the colours for the labels in the partition.
palette = {cl : col for cl, col in zip(class_list, color_list[1:])}

# New SDD with the flow_partition set
sdd = flo.SankeyDefinition(nodes, bundles, ordering,
                       flow_partition=class_flow_part)



# NOTE: `w` (the loop index above) is reused for the weave result
w = flo.weave(sdd, df_flows, palette=palette)




#%% save flow and render svg
import json

# Output files for the Sankey data, named after the tau_w used for the flows.
json_file = os.path.join(figdir,f'mice_net_weekly_tau_w{tau_w}.json')
svg_file = os.path.join(figdir,f'mice_net_weekly_tau_w{tau_w}.svg')
with open(json_file, 'w') as fopen:
    json.dump(w.to_json(), fopen)



# Render the JSON to SVG with the external `svg-sankey` command, run through
# the shell with its output redirected to `svg_file`.
os.system(f'svg-sankey --size 3000,1000 --font-size 24 --margins 10,150 {json_file}  > {svg_file}')

#remove node labels
# The sed command edits the SVG in place and blanks the text of labels such as
# 'wf1_00' / 'wb1_00'. The string is raw so that the backslashes of the sed
# expression reach the shell unchanged without relying on the invalid Python
# escape sequence '\/'.
os.system(rf"sed -i 's/>w[b-f][0-9]_[0-9][0-9]<\/text/><\/text/g' {svg_file}")

